import {Token, Tokenizer} from "../types.js";
import {LlamaText} from "../utils/LlamaText.js";
import {tokenizeInput} from "../utils/tokenizeInput.js";
import type {LlamaModel} from "./LlamaModel/LlamaModel.js";

/**
 * Holds a map of per-token logit biases, resolved using a given tokenizer.
 * @see [Using Token Bias](https://node-llama-cpp.withcat.ai/guide/token-bias) tutorial
 */
export class TokenBias {
    /** @internal */ public readonly _tokenizer: Tokenizer;
    /** @internal */ public readonly _biases = new Map<Token, number>();

    public constructor(tokenizer: Tokenizer) {
        this._tokenizer = tokenizer;
    }

    /**
     * Set the bias of the given token(s), overriding any bias previously set for them.
     *
     * The input is tokenized twice - once as-is and once with the `"trimLeadingSpace"` option -
     * and the same bias is stored for every individual token of both results.
     *
     * End-of-generation tokens (EOS, EOT, etc.) are always skipped, so attempting to bias them has no effect.
     * @param input - The token(s) to apply the bias to.
     * If a text is provided, the bias is applied to each individual token of the text.
     * @param bias - The bias to apply to the token(s).
     *
     * A number is a probability-like value that is converted to a logit bias:
     * positive values are meant to make the token(s) more likely to be generated,
     * negative values are meant to make them less likely, and `0` has no effect.
     * A value of `1` or more becomes an infinite bias, and a value of `-1` or less becomes a negative infinite bias.
     *
     * `"never"` is stored as a negative infinite bias,
     * which prevents the token from being generated unless it is required to comply with a grammar.
     *
     * An object of `{logit: number}` is stored as-is, without any conversion.
     *
     * Try to play around with values between `0.9` and `-0.9` to see what works for your use case.
     * @returns The same `TokenBias` instance, to allow chaining calls
     */
    public set(input: Token | Token[] | string | LlamaText, bias: "never" | number | {logit: number}) {
        let logit: number;
        if (bias === "never")
            logit = -Infinity;
        else if (typeof bias === "number")
            logit = probabilityToLogit(bias);
        else
            logit = bias.logit;

        const applyTo = (tokens: readonly Token[]) => {
            for (const token of tokens) {
                // end-of-generation tokens are never biased
                if (!this._tokenizer.isEogToken(token))
                    this._biases.set(token, logit);
            }
        };

        // cover both tokenization variants of the input, so the bias also applies
        // to the tokens produced when the leading space is trimmed
        applyTo(tokenizeInput(input, this._tokenizer));
        applyTo(tokenizeInput(input, this._tokenizer, "trimLeadingSpace"));

        return this;
    }

    /**
     * Create a new `TokenBias` for the given model or tokenizer.
     *
     * When the given value has a non-nullish `tokenizer` property, that tokenizer is used;
     * otherwise, the given value itself is used as the tokenizer.
     */
    public static for(modelOrTokenizer: LlamaModel | Tokenizer) {
        const modelTokenizer = (modelOrTokenizer as LlamaModel).tokenizer;

        return new TokenBias(
            modelTokenizer != null
                ? modelTokenizer
                : modelOrTokenizer as Tokenizer
        );
    }
}

/**
 * Convert a probability-like bias value to a logit bias.
 *
 * - Values of `1` or more map to `Infinity`, and values of `-1` or less map to `-Infinity`.
 * - `0` maps to `0` (no bias).
 * - Any other value is mapped by taking the log-odds (`ln(p / (1 - p))`) of its magnitude,
 *   and negating the result for negative inputs, so that `f(-p) === -f(p)`.
 *
 * NOTE(review): the log-odds of `0.5` is `0`, so a magnitude below `0.5` yields a logit
 * whose sign is opposite to the sign of the input. This matches the pre-existing behavior
 * for positive values and is kept as-is for backward compatibility.
 * @param probability - The probability-like bias; meaningful values are in the range `(-1, 1)`
 * @returns The logit bias. `NaN` is returned only when the input itself is `NaN`
 */
function probabilityToLogit(probability: number): number {
    if (probability <= -1)
        return -Infinity;
    else if (probability >= 1)
        return Infinity;
    else if (probability === 0)
        return 0;

    // `Math.log` returns `NaN` for a negative argument, which is what `p / (1 - p)` is for a negative `p`,
    // so compute the log-odds on the magnitude and mirror the result for negative inputs
    const magnitude = Math.abs(probability);
    const logit = Math.log(magnitude / (1 - magnitude));

    return probability < 0
        ? -logit
        : logit;
}
